from layers.custom_attn_processor import CustomAttnProcessor2_0
from layers.custom_resnet_block import CustomResnetBlock2D

# Processor class that ``custom_attn_modules_register`` instantiates for every
# selected ``attn1`` layer. The name is looked up at call time, so rebinding this
# module-level alias changes the processor used by later registrations.
attn_processor = CustomAttnProcessor2_0


def set_custom_invert_mode(diffusion_models, mode="invert"):
    """
    Switch every custom resnet block and custom attention processor to ``mode``.

    ``mode`` is forwarded to ``set_invert_or_generate`` and selects invert or
    generation behaviour (``"invert"`` by default). ``CustomResnetBlock2D``
    modules receive the call themselves; ``Attention`` modules forward it to
    their processor (``get_processor()``). Any other module is skipped.

    :param diffusion_models: iterable of modules, e.g. the list built by ``get_custom_modules``.
    :param mode: mode value passed through to ``set_invert_or_generate``.
    """
    for module in diffusion_models:
        if isinstance_str(module, "CustomResnetBlock2D"):
            target = module
        elif isinstance_str(module, "Attention"):
            target = module.get_processor()
        else:
            continue
        target.set_invert_or_generate(mode)


def set_custom_load_mode(diffusion_models, load_feature=True, load_attn=True):
    """
    Toggle load mode on every custom resnet block and custom attention processor.

    ``CustomResnetBlock2D`` modules get ``set_load_mode(load_feature)``.
    ``Attention`` modules forward ``set_load_mode(load_attn)`` to their processor.
    Any other module is skipped.

    :param diffusion_models: iterable of modules, e.g. the list built by ``get_custom_modules``.
    :param load_feature: flag passed to the resnet blocks.
    :param load_attn: flag passed to the attention processors.
    """
    for module in diffusion_models:
        if isinstance_str(module, "CustomResnetBlock2D"):
            target, flag = module, load_feature
        elif isinstance_str(module, "Attention"):
            target, flag = module.get_processor(), load_attn
        else:
            continue
        target.set_load_mode(flag)

def reset_custom_cache(diffusion_models):
    """
    Call ``reset_cache()`` on every custom resnet block and custom attention processor.

    ``CustomResnetBlock2D`` modules are reset directly. ``Attention`` modules have
    their processor (``get_processor()``) reset. Any other module is skipped.

    :param diffusion_models: iterable of modules, e.g. the list built by ``get_custom_modules``.
    """
    for module in diffusion_models:
        if isinstance_str(module, "CustomResnetBlock2D"):
            target = module
        elif isinstance_str(module, "Attention"):
            target = module.get_processor()
        else:
            continue
        target.reset_cache()


def set_custom_cur_t(diffusion_models, cur_t):
    """
    Broadcast the current time step to every custom resnet block and attention processor.

    ``CustomResnetBlock2D`` modules get ``set_cur_t(cur_t)`` directly.
    ``Attention`` modules forward it to their processor. Any other module is skipped.

    :param diffusion_models: iterable of modules, e.g. the list built by ``get_custom_modules``.
    :param cur_t: current time step, passed through unchanged.
    """
    for module in diffusion_models:
        if isinstance_str(module, "CustomResnetBlock2D"):
            target = module
        elif isinstance_str(module, "Attention"):
            target = module.get_processor()
        else:
            continue
        target.set_cur_t(cur_t)


def save_feature_and_attn(diffusion_models):
    """
    Ask every custom module to write its cache out via ``save_cache_to_file()``.

    ``CustomResnetBlock2D`` modules save their own cache (features). For
    ``Attention`` modules the call goes to their processor (``get_processor()``),
    which holds the attention cache. Any other module is skipped.

    :param diffusion_models: iterable of modules, e.g. the list built by ``get_custom_modules``.
    """
    for module in diffusion_models:
        if isinstance_str(module, "CustomResnetBlock2D"):
            module.save_cache_to_file()
        elif isinstance_str(module, "Attention"):
            module.get_processor().save_cache_to_file()


def load_feature_and_attn(diffusion_models):
    """
    Ask every custom module to restore its cache via ``load_cache_from_file()``.

    This is the counterpart of ``save_feature_and_attn``. ``CustomResnetBlock2D``
    modules load their own cache (features). For ``Attention`` modules the call
    goes to their processor (``get_processor()``). Any other module is skipped.

    :param diffusion_models: iterable of modules, e.g. the list built by ``get_custom_modules``.
    """
    for module in diffusion_models:
        if isinstance_str(module, "CustomResnetBlock2D"):
            module.load_cache_from_file()
        elif isinstance_str(module, "Attention"):
            module.get_processor().load_cache_from_file()


def custom_module_register(pipe, config):
    """
    Install the custom resnet blocks and custom attention processors into ``pipe.unet``.

    Nothing happens unless ``config.generation.cache.enabled`` is truthy. When it is:

    * every resnet selected by ``down_res_dict`` / ``up_res_dict`` is replaced in
      place by a ``CustomResnetBlock2D`` built with ``set_custom_resnet``;
    * every attention selected by ``down_attn_dict`` / ``up_attn_dict`` /
      ``mid_attn_dict`` gets a fresh custom processor on the ``attn1`` of its first
      transformer block (see ``custom_attn_modules_register``).

    Each ``*_dict`` maps a block index to an iterable of resnet/attention indices
    within that block. The keys of ``mid_attn_dict`` are not used for indexing.

    :param pipe: pipeline whose ``unet`` exposes ``down_blocks``, ``up_blocks`` and ``mid_block``.
    :param config: configuration object providing ``generation.cache``.
    """
    cache_cfg = config.generation.cache
    if not cache_cfg.enabled:
        return

    # Read every index mapping before touching the UNet, so a config that lacks
    # one of them fails before any module has been swapped.
    cache_down_attn_dict = cache_cfg.down_attn_dict
    cache_mid_attn_dict = cache_cfg.mid_attn_dict
    cache_up_attn_dict = cache_cfg.up_attn_dict
    cache_down_res_dict = cache_cfg.down_res_dict
    cache_up_res_dict = cache_cfg.up_res_dict

    unet = pipe.unet

    # Resnets: down blocks first, then up blocks.
    for blocks, res_dict in ((unet.down_blocks, cache_down_res_dict),
                             (unet.up_blocks, cache_up_res_dict)):
        for block_num in res_dict.keys():
            for resnet_num in res_dict[block_num]:
                resnets = blocks[block_num].resnets
                resnets[resnet_num] = set_custom_resnet(resnets[resnet_num])

    # Always truthy at this point (guarded above); kept as the config value so the
    # processors see exactly what the config holds.
    use_cache = cache_cfg.enabled
    custom_attn_modules_register(attn_dict=cache_down_attn_dict,
                                 model=unet.down_blocks,
                                 use_cache=use_cache)
    custom_attn_modules_register(attn_dict=cache_up_attn_dict,
                                 model=unet.up_blocks,
                                 use_cache=use_cache)
    custom_attn_modules_register(attn_dict=cache_mid_attn_dict,
                                 model=unet.mid_block,
                                 use_cache=use_cache,
                                 is_mid=True)


def set_custom_resnet(module, use_cache=True):
    """
    Build a ``CustomResnetBlock2D`` (PnP feature caching) that mirrors ``module``.

    The new block is constructed from hyper-parameters read back from ``module``.
    Its sub-layers are then replaced by the very same layer objects from
    ``module``, so the returned block shares weights with the original instead of
    holding freshly initialised ones. ``module`` itself is only read, never modified.

    :param module: the resnet block being replaced. It must expose the attributes
        read below (presumably a diffusers ``ResnetBlock2D`` — TODO confirm).
    :param use_cache: forwarded unchanged to ``CustomResnetBlock2D``.
    :return: the new ``CustomResnetBlock2D`` instance.
    """
    # ``time_emb_proj`` may be None on blocks built without a time embedding;
    # reading ``.in_features`` from None would raise AttributeError. Passing None
    # assumes CustomResnetBlock2D accepts ``temb_channels=None`` like its diffusers
    # counterpart — TODO confirm.
    time_emb_proj = module.time_emb_proj
    temb_channels = None if time_emb_proj is None else time_emb_proj.in_features

    custom_resnet_module = CustomResnetBlock2D(in_channels=module.in_channels,
                                               out_channels=module.out_channels,
                                               conv_shortcut=module.use_conv_shortcut,
                                               dropout=module.dropout.p,
                                               temb_channels=temb_channels,
                                               groups=module.norm1.num_groups,
                                               groups_out=module.norm2.num_groups,
                                               pre_norm=module.pre_norm,
                                               eps=module.norm1.eps,
                                               skip_time_act=module.skip_time_act,
                                               time_embedding_norm=module.time_embedding_norm,
                                               output_scale_factor=module.output_scale_factor,
                                               use_in_shortcut=module.use_in_shortcut,
                                               up=module.up,
                                               down=module.down,
                                               conv_2d_out_channels=module.conv2.out_channels,
                                               use_cache=use_cache,
                                               )

    # Reuse the original layer objects so the trained weights (and whatever
    # device/dtype they live on) carry over unchanged.
    custom_resnet_module.norm1 = module.norm1
    custom_resnet_module.conv1 = module.conv1
    custom_resnet_module.time_emb_proj = module.time_emb_proj
    custom_resnet_module.norm2 = module.norm2
    custom_resnet_module.dropout = module.dropout
    custom_resnet_module.conv2 = module.conv2
    custom_resnet_module.nonlinearity = module.nonlinearity
    custom_resnet_module.conv_shortcut = module.conv_shortcut
    if module.up:
        custom_resnet_module.upsample = module.upsample
    if module.down:
        custom_resnet_module.downsample = module.downsample

    # A freshly constructed nn.Module starts in training mode. Match the block it
    # replaces (e.g. a UNet already switched to eval()).
    custom_resnet_module.train(module.training)
    return custom_resnet_module


def custom_attn_modules_register(attn_dict, model, use_cache, is_mid=False):
    """
    Attach a new ``attn_processor`` instance to each selected ``attn1`` layer.

    For every block index / attention index pair in ``attn_dict``, the processor is
    set on ``attentions[attn_num].transformer_blocks[0].attn1`` of
    ``model[block_num]``. When ``is_mid`` is True, ``model`` itself is used instead
    and ``block_num`` is ignored. Every layer gets its own processor created with
    ``load_attn=False`` and the given ``use_cache``.

    :param attn_dict: mapping of block index -> iterable of attention indices.
    :param model: indexable sequence of UNet blocks, or a single block when ``is_mid``.
    :param use_cache: forwarded to the processor constructor.
    :param is_mid: treat ``model`` as one block instead of an indexable sequence.
    """
    for block_num, attn_nums in attn_dict.items():
        for attn_num in attn_nums:
            owner = model if is_mid else model[block_num]
            attn_layer = owner.attentions[attn_num].transformer_blocks[0].attn1
            attn_layer.set_processor(attn_processor(load_attn=False, use_cache=use_cache))


def get_custom_modules(model, config):
    """
    Collect the modules selected by ``config.generation.cache``.

    The result is suitable as the ``diffusion_models`` argument of the
    ``set_custom_*`` / ``*_feature_and_attn`` helpers in this module. It contains,
    in this order:

    1. the selected up-block resnets;
    2. the selected down-block resnets;
    3. the ``attn1`` layer of the first transformer block of each selected down
       attention, then up attention, then mid attention.

    Every module appears at most once (first occurrence wins, compared by
    identity), so a repeated index in the config never makes a module be visited
    twice. The keys of ``mid_attn_dict`` are not used for indexing.

    NOTE: the resnets are only ``CustomResnetBlock2D`` instances if
    ``custom_module_register`` has already replaced them (it does so only when
    caching is enabled).

    :param model: UNet exposing ``down_blocks``, ``up_blocks`` and ``mid_block``.
    :param config: configuration object providing ``generation.cache``.
    :return: list of the selected modules.
    """
    cache_cfg = config.generation.cache
    cache_down_attn_dict = cache_cfg.down_attn_dict
    cache_mid_attn_dict = cache_cfg.mid_attn_dict
    cache_up_attn_dict = cache_cfg.up_attn_dict
    cache_down_res_dict = cache_cfg.down_res_dict
    cache_up_res_dict = cache_cfg.up_res_dict

    modules = []
    seen_ids = set()

    def _add(module):
        # Identity-based dedup: O(1) per lookup instead of scanning ``modules``.
        # The ids stay unique because ``modules`` keeps every object alive.
        if id(module) not in seen_ids:
            seen_ids.add(id(module))
            modules.append(module)

    def _first_attn1(attention):
        return attention.transformer_blocks[0].attn1

    for block_num, resnet_nums in cache_up_res_dict.items():
        for resnet_num in resnet_nums:
            _add(model.up_blocks[block_num].resnets[resnet_num])

    for block_num, resnet_nums in cache_down_res_dict.items():
        for resnet_num in resnet_nums:
            _add(model.down_blocks[block_num].resnets[resnet_num])

    for block_num, attn_nums in cache_down_attn_dict.items():
        for attn_num in attn_nums:
            _add(_first_attn1(model.down_blocks[block_num].attentions[attn_num]))

    for block_num, attn_nums in cache_up_attn_dict.items():
        for attn_num in attn_nums:
            _add(_first_attn1(model.up_blocks[block_num].attentions[attn_num]))

    for attn_nums in cache_mid_attn_dict.values():
        for attn_num in attn_nums:
            _add(_first_attn1(model.mid_block.attentions[attn_num]))

    return modules


def isinstance_str(x: object, cls_name: str):
    """
    Report whether any class in ``x``'s method resolution order is named ``cls_name``.

    Only the class *names* are compared, so the check works without importing
    the class itself. This makes it handy for patching third-party modules.
    """
    return any(klass.__name__ == cls_name for klass in x.__class__.__mro__)
